/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/runtime/contrib/json/json_runtime.h
 * \brief Utilities for json runtime.
 */

#ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_
#define TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_

#include <tvm/ffi/extra/module.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/tensor.h>

#include <algorithm>
#include <cstddef>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "json_node.h"

namespace tvm {
namespace runtime {
namespace json {

/*!
 * \brief A json runtime that executes the serialized JSON format. This runtime
 * can be extended by user defined runtime for execution.
 */
class JSONRuntimeBase : public ffi::ModuleObj {
 public:
  /*!
   * \brief Construct the base JSON runtime.
   * \param symbol_name The only subgraph symbol served by this module.
   * \param graph_json The serialized JSON graph to execute.
   * \param const_names The names of the constants required by the graph, in order.
   */
  JSONRuntimeBase(const std::string& symbol_name, const std::string& graph_json,
                  const ffi::Array<ffi::String>& const_names)
      : symbol_name_(symbol_name), graph_json_(graph_json), const_names_(const_names) {
    LoadGraph(graph_json_);
  }

  const char* kind() const override { return "json"; }  // May be overridden

  /*! \brief Get the property of the runtime module. */
  int GetPropertyMask() const override {
    return ffi::Module::kBinarySerializable | ffi::Module::kRunnable;
  }

  /*!
   * \brief Initialize a specific json runtime.
   * \param consts The constant tensors, in the order recorded in const_names_.
   */
  virtual void Init(const ffi::Array<Tensor>& consts) = 0;

  /*! \brief Invoke the execution engine to interpret a specific json runtime. */
  virtual void Run() = 0;

  /*! \brief Does the backend support debug & profiling? Defaults to no. */
  virtual bool CanDebug() { return false; }

  /*!
   * \brief Invoke the profiler. Only called when CanDebug() returns true;
   * the default implementation aborts.
   * \param prof Pointer to the profiler.
   */
  virtual void RunProfile(profiling::Profiler* prof) {
    LOG(FATAL) << "Not expected to be here : Profiling call w/o support ?";
  }

  /*!
   * \brief Invoke the debugger. Only called when CanDebug() returns true;
   * the default implementation aborts.
   * \return External compiler specific debug blob.
   */
  virtual std::string DebugDump() {
    LOG(FATAL) << "Not expected to be here : Debug dump w/o support ?";
  }

  /*!
   * \brief Get a packed function.
   *
   * Exposes: "get_symbol", "get_const_vars", "<symbol>" (run), "<symbol>_debug"
   * (debug dump / profiling, only if CanDebug()), and "__init_<symbol>"
   * (one-time constant initialization).
   *
   * \param name The name/symbol of the function.
   * \return The packed function, or std::nullopt if the name is not recognized.
   */
  ffi::Optional<ffi::Function> GetFunction(const ffi::String& name) override {
    // Keep the module alive for as long as any returned closure is alive.
    ObjectPtr<Object> sptr_to_self = ffi::GetObjectPtr<Object>(this);
    if (name == "get_symbol") {
      return ffi::Function(
          [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; });
    } else if (name == "get_const_vars") {
      return ffi::Function(
          [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->const_names_; });
    } else if (this->symbol_name_ == name) {
      return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) {
        ICHECK(this->initialized_) << "The module has not been initialized";

        // Bind argument tensors to data entries.
        this->SetInputOutputBuffers(args);

        // Execute the subgraph.
        this->Run();
      });
    } else if (this->symbol_name_ + "_debug" == name) {
      // NOTE: the current debug convention is not very compatible with
      // the FFI convention, consider clean up
      if (!this->CanDebug()) {
        return ffi::Function(nullptr);
      }
      return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) {
        ICHECK(this->initialized_) << "The module has not been initialized";

        // Bind argument tensors to data entries.
        this->SetInputOutputBuffers(args);

        // NOTE(review): *rv doubles as an input here — a string selects a debug
        // dump, otherwise it is interpreted as a profiler pointer.
        if (auto opt_str = rv->try_cast<ffi::String>()) {
          ffi::String purpose = std::move(opt_str.value());
          if ("debug_dump" == purpose) {
            *rv = this->DebugDump();
          }
        } else {
          // Profile the subgraph.
          profiling::Profiler* prof = static_cast<profiling::Profiler*>(rv->cast<void*>());
          this->RunProfile(prof);
        }
      });
    } else if ("__init_" + this->symbol_name_ == name) {
      // The function to initialize constant tensors.
      return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) {
        ICHECK_EQ(args.size(), 1U);
        // Double-checked under the mutex so concurrent callers initialize once.
        std::lock_guard<std::mutex> guard(this->initialize_mutex_);
        if (!this->initialized_) {
          this->Init(args[0].cast<ffi::Array<Tensor>>());
          this->initialized_ = true;
        }
        *rv = 0;
      });
    } else {
      return std::nullopt;
    }
  }

  /*!
   * \brief Serialize the module: symbol name, graph JSON, then const names.
   * Must mirror the read order in LoadFromBytes.
   */
  ffi::Bytes SaveToBytes() const override {
    std::string buffer;
    dmlc::MemoryStringStream ms(&buffer);
    dmlc::Stream* stream = &ms;
    // Save the symbol
    stream->Write(symbol_name_);
    // Save the graph
    stream->Write(graph_json_);
    // Save the required const names
    std::vector<std::string> consts;
    for (const auto& it : const_names_) {
      consts.push_back(it);
    }
    stream->Write(consts);
    return ffi::Bytes(buffer);
  }

  /*!
   * \brief Deserialize a module previously produced by SaveToBytes.
   * \tparam T The concrete runtime type to instantiate (must derive from JSONRuntimeBase).
   * \param bytes The serialized blob.
   * \return The reconstructed module.
   */
  template <typename T,
            typename = typename std::enable_if<std::is_base_of<JSONRuntimeBase, T>::value>::type>
  static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) {
    dmlc::MemoryFixedSizeStream ms(const_cast<char*>(bytes.data()), bytes.size());
    dmlc::Stream* stream = &ms;
    std::string symbol;
    std::string graph_json;
    std::vector<std::string> consts;
    // Load in the exact order written by SaveToBytes.
    ICHECK(stream->Read(&symbol)) << "Loading symbol name failed";
    ICHECK(stream->Read(&graph_json)) << "Loading graph json failed";
    ICHECK(stream->Read(&consts)) << "Loading the const name list failed";
    ffi::Array<ffi::String> const_names;
    for (const auto& it : consts) {
      const_names.push_back(it);
    }
    auto n = ffi::make_object<T>(symbol, graph_json, const_names);
    return ffi::Module(n);
  }

  /*!
   * \brief Get the JSON generated by codegen.
   *
   * \param format the format to return (ignored; the graph JSON is always returned).
   * \return A string of JSON.
   */
  ffi::String InspectSource(const ffi::String& format) const override { return graph_json_; }

 protected:
  /*!
   * \brief Set up the input and output buffers by binding their DLTensor pointers to the
   * corresponding data entry.
   *
   * Inputs occupy args[0..num_inputs); outputs follow in args[num_inputs..).
   *
   * \param args The packed args.
   */
  void SetInputOutputBuffers(const ffi::PackedArgs& args) {
    ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size())
        << "Found mismatch in the number of provided data entries and required.";

    for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
      auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
                                           : EntryID(outputs_[i - input_var_eid_.size()]);

      // Accept either a managed Tensor or a raw DLTensor* argument.
      const DLTensor* arg;
      if (auto opt_nd = args[i].as<Tensor>()) {
        Tensor arr = opt_nd.value();
        arg = arr.operator->();
      } else {
        arg = args[i].cast<DLTensor*>();
      }

      // Assign input/output the Tensor pointers to data entry so that we can directly
      // read/write host buffers.
      data_entry_[eid] = arg;
    }
  }

  /*!
   * \brief Load the graph and record the entries for inputs and constants.
   *
   * Also validates that the constants referenced by the graph exactly match
   * const_names_ in both membership and order.
   *
   * \param graph_json The graph in the json format.
   */
  void LoadGraph(const std::string& graph_json) {
    std::istringstream is(graph_json);
    dmlc::JSONReader reader(&is);
    this->Load(&reader);
    std::vector<std::string> consts;
    for (size_t i = 0; i < input_nodes_.size(); i++) {
      uint32_t nid = input_nodes_[i];
      std::string name = nodes_[nid].name_;
      if (nodes_[nid].op_type_ == "input") {
        ICHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size());
        // An input node may produce multiple entries (one per shape/dtype pair).
        for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
          input_var_eid_.push_back(EntryID(nid, j));
        }
        nodes_[nid].SetNumOutput(nodes_[nid].GetOpShape().size());
      } else {
        ICHECK_EQ(nodes_[nid].op_type_, "const");
        auto pos = std::find(std::begin(const_names_), std::end(const_names_), name);
        ICHECK(pos != std::end(const_names_)) << "Found non-existent constant: " << name;
        const_idx_.push_back(nid);
        consts.push_back(name);
      }
    }
    ICHECK_EQ(consts.size(), const_names_.size())
        << "Found mismatch for the number of constants in the graph and required.";

    for (size_t i = 0; i < consts.size(); i++) {
      ICHECK_EQ(consts[i], const_names_[i])
          << "The position of constant in the graph must be the same as the required.";
    }

    // Reserve data entries.
    data_entry_.resize(NumEntries());
  }

  /*!
   * \brief Set up the constants/weights for inference by binding their DLTensor pointer to
   * the corresponding data entry.
   *
   * \param consts A list of constant Tensor to be used; must align one-to-one
   * with the constant nodes recorded during LoadGraph.
   */
  void SetupConstants(const ffi::Array<Tensor>& consts) {
    // Guard const_idx_ indexing: a short list would otherwise read out of bounds.
    ICHECK_EQ(consts.size(), const_idx_.size())
        << "Found mismatch in the number of provided constants and required.";
    for (size_t i = 0; i < consts.size(); ++i) {
      data_entry_[EntryID(const_idx_[i], 0)] = consts[i].operator->();
    }
  }

  // Load the graph. Recognized keys: nodes, arg_nodes, node_row_ptr, heads, symbol.
  void Load(dmlc::JSONReader* reader) {
    reader->BeginObject();
    std::string key;
    // The "symbol" field is read into a local and discarded; the module's
    // symbol name is supplied via the constructor instead.
    std::string symbol_;
    while (reader->NextObjectItem(&key)) {
      if (key == "nodes") {
        reader->Read(&nodes_);
      } else if (key == "arg_nodes") {
        reader->Read(&input_nodes_);
      } else if (key == "node_row_ptr") {
        reader->Read(&node_row_ptr_);
      } else if (key == "heads") {
        reader->Read(&outputs_);
      } else if (key == "symbol") {
        reader->Read(&symbol_);
      } else {
        LOG(FATAL) << "Unknown key: " << key;
      }
    }
  }

  // Get the node entry index for the index-th output of node nid.
  uint32_t EntryID(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; }

  // Get the node entry index.
  uint32_t EntryID(const JSONGraphNodeEntry& e) const { return EntryID(e.id_, e.index_); }

  // Number of node entries.
  uint32_t NumEntries() const { return node_row_ptr_.back(); }

 protected:
  /*! \brief The only subgraph name for this module. */
  std::string symbol_name_;
  /*! \brief The graph. */
  std::string graph_json_;
  /*! \brief The required constant names. */
  ffi::Array<ffi::String> const_names_;
  /*! \brief The json graph nodes. */
  std::vector<JSONGraphNode> nodes_;
  /*! \brief The input nodes, including variables and constants. */
  std::vector<uint32_t> input_nodes_;
  /*! \brief Used for quick entry indexing. */
  std::vector<uint32_t> node_row_ptr_;
  /*! \brief Output entries. */
  std::vector<JSONGraphNodeEntry> outputs_;
  /*! \brief Data of that entry. */
  std::vector<const DLTensor*> data_entry_;
  /*! \brief Map the input name to entry id. */
  std::vector<uint32_t> input_var_eid_;
  /*! \brief input const node index. */
  std::vector<uint32_t> const_idx_;
  /*! \brief Indicate if the engine has been initialized. */
  bool initialized_{false};
  /*! \brief Initializer mutex. */
  std::mutex initialize_mutex_;
};

}  // namespace json
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_
